import os
import logging
import pickle
import json
import argparse
import itertools
import time
from typing import Union
from multiprocessing import Pool, cpu_count
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from pyomo.environ import *

# Module-level IPOPT solver handle shared by the rest of the file.
# An explicit executable is used only when the configured path exists on disk;
# otherwise `executable_path` is reset to None and SolverFactory is asked for
# 'ipopt' without an explicit executable.
executable_path = r"path/to/ipopt.exe"  # Update this path to your IPOPT executable
if os.path.exists(executable_path):
    solver = SolverFactory('ipopt', executable=executable_path)
else:
    executable_path = None
    solver = SolverFactory('ipopt')

# Solver options forwarded to IPOPT.
solver.options['max_iter'] = 10000
solver.options['halt_on_ampl_error'] = 'yes'


def sample_mdp(S, A, n, p):
    """
    Sample an empirical success-probability table of shape (S, A).

    For each state s and action a, n Bernoulli(p) trials are summarised by a
    single Binomial(n, p) draw, and the empirical frequency (successes / n)
    is stored in P[s, a].

    :param S: number of states (rows of the result)
    :param A: number of actions (columns of the result)
    :param n: number of Bernoulli trials per (state, action) pair; must be > 0
    :param p: success probability of each trial
    :return: float array P of shape (S, A) with entries in [0, 1]
    :raises ValueError: if n is not positive
    """
    if n <= 0:
        # Dividing by n below would otherwise produce a ZeroDivisionError / NaNs.
        raise ValueError("n must be a positive number of samples")
    # One vectorised call replaces S * A scalar calls. The global NumPy RNG
    # fills the array in row-major order, so results under np.random.seed(...)
    # match the former per-entry double loop.
    successes = np.random.binomial(n, p, size=(S, A))
    return successes / n


def kl_divergence(p, p_nominal):
    r"""
    Left-hand side of the KL divergence constraint
    \sum_{s^\prime} P(s^\prime|s,a) \log \frac{P(s^\prime|s,a)}{P(s^\prime|s,a)_{nominal}} \leq \delta
    evaluated for two outcomes with probabilities (p, 1 - p) against the
    nominal probabilities (p_nominal, 1 - p_nominal).

    The expression is evaluated exactly as written: there is no special
    handling when p or p_nominal equals 0 or 1, so callers must avoid (or
    handle) those boundary values themselves.
    """
    q = 1 - p
    q_nominal = 1 - p_nominal
    first_outcome = p * np.log(p / p_nominal)
    second_outcome = q * np.log(q / q_nominal)
    return first_outcome + second_outcome


def f_divergence(p, p_nominal, k):
    r"""
    Left-hand side of the f-divergence constraint
    \sum_{s^\prime} \frac{t_{s^\prime}^k-t_{s^\prime}k+k-1}{k(k-1)}
    where t_{s^\prime} = \frac{P(s^\prime|s,a)}{P(s^\prime|s,a)_{nominal}},
    summed over two outcomes with probabilities (p, 1 - p) against the
    nominal probabilities (p_nominal, 1 - p_nominal).

    The denominator k(k-1) is zero for k = 0 and k = 1, so those exponents
    are not usable with this formula.
    """
    ratio_first = p / p_nominal
    ratio_second = (1 - p) / (1 - p_nominal)
    denominator = k * (k - 1)
    term_first = (ratio_first ** k - k * ratio_first + k - 1) / denominator
    term_second = (ratio_second ** k - k * ratio_second + k - 1) / denominator
    return term_first + term_second


def solve_general_mdp(p_nominal: Union[np.ndarray, float], gamma: float, delta: float, divergence: str, k: float = 2.0):
    r"""
    The solution of General MDP can be formulated as:
    v(s) = \frac{1}{1-\gamma P_{s,a}}\\
    {\rm s.t.} D_f(P_{s,a}\|\bar P_{s,a})\leq \delta

    Since v(s) is increasing in P_{s,a}, minimizing v(s) means minimizing
    P_{s,a}. Bisection on [0, p_nominal] finds the boundary between
    probabilities whose divergence from p_nominal is below delta and those
    whose divergence is not.

    :param p_nominal: nominal probability \bar P_{s,a}; the comparisons below
        treat it as a scalar
    :param gamma: discount factor
    :param delta: upper bound for the divergence constraint
    :param divergence: 'kl' (kl_divergence) or 'f' (f_divergence)
    :param k: exponent forwarded to f_divergence; unused for 'kl'
    :return: 1 / (1 - gamma * p) for the p found by bisection
    :raises ValueError: if divergence is neither 'kl' nor 'f'
    """
    # Validate once, up front: previously an invalid name was only detected
    # inside the loop and slipped through whenever the loop did not run.
    if divergence not in ('kl', 'f'):
        raise ValueError("divergence must be 'kl' or 'f'")

    # When p_nominal = 1 both divergences divide by (1 - p_nominal) = 0, so the
    # case is handled separately for 'kl' and 'f' alike: every p < 1 has
    # unbounded divergence, leaving p = p_nominal as the only candidate.
    if p_nominal == 1:
        return 1 / (1 - gamma * p_nominal)

    use_kl = divergence == 'kl'  # resolved once instead of on every iteration

    # using bisection to find the optimal P_{s,a}
    p_low = 0
    p_high = p_nominal
    p = (p_low + p_high) / 2

    while abs(p_high - p_low) > 1e-10:
        if use_kl:
            div = kl_divergence(p, p_nominal)
        else:
            div = f_divergence(p, p_nominal, k)
        if div < delta:
            # p satisfies the constraint: try an even smaller probability
            p_high = p
        else:
            # p is too far from p_nominal (or div is NaN): move back towards it
            p_low = p
        p = (p_low + p_high) / 2
    return 1 / (1 - gamma * p)


def compute_max_value_for_state(A, P_s, gamma, delta, divergence, k):
    """
    Return the largest solve_general_mdp value over the actions of one state.

    P_s[a] is passed as the nominal probability of action a, for
    a = 0, ..., A - 1. When A is 0 no action is evaluated and -inf is returned.
    """
    best_value = -float("inf")
    for action in range(A):
        candidate = solve_general_mdp(P_s[action], gamma, delta, divergence, k)
        # Replace only on a strict improvement, exactly like max(best, candidate).
        if candidate > best_value:
            best_value = candidate
    return best_value


def single_experiment(S, A, n, p_nominal, gamma, delta, divergence, k=2.0):
    """
    Build one (S, A) probability table and compute a value for every state.

    When n is None the table is filled with p_nominal; otherwise it is sampled
    with sample_mdp(S, A, n, p_nominal). Each row is then handed to
    compute_max_value_for_state in a process pool of at most 32 workers.

    :return: (v, P) where v is an array with one value per state and P is the
        probability table that was used
    """
    P = sample_mdp(S, A, n, p_nominal) if n is not None else np.ones(shape=(S, A)) * p_nominal

    # One task per state: the state's row of P plus the shared parameters.
    tasks = [(A, state_row, gamma, delta, divergence, k) for state_row in P]
    worker_count = min(32, cpu_count())
    with Pool(processes=worker_count) as pool:
        state_values = pool.starmap(compute_max_value_for_state, tasks)
    return np.array(list(state_values)), P


def all_experiment(S_values, A_values, n_values, runs, gamma, delta, divergence, k=2.0,
                   p_nominal=None, out_dir=None):
    """
    Run single_experiment over the grid S_values x A_values x n_values and
    pickle the collected results.

    For n = None a single deterministic experiment (run 0) is executed; for
    every other n the experiment is repeated `runs` times with a seed built
    from the digits of S, A and run (the seed does not depend on n).

    :param S_values: state counts to sweep (sequence; its first entry names the output file)
    :param A_values: action counts to sweep (sequence; its first entry names the output file)
    :param n_values: sample sizes to sweep; None selects the non-sampled table
    :param runs: number of seeded repetitions for each non-None n
    :param gamma: discount factor
    :param delta: upper bound for the divergence constraint
    :param divergence: 'kl' or 'f'
    :param k: exponent for the f-divergence; forwarded to single_experiment
    :param p_nominal: nominal transition probability; when None the
        module-level variable `p` is used (previous behaviour)
    :param out_dir: directory for the pickle file; when None the module-level
        variable `save_dir` is used (previous behaviour)
    :return: dict mapping (S, A, n, run) to a record with the parameters and
        the value array "v"
    """
    # Backward compatibility: earlier callers relied on these module globals.
    if p_nominal is None:
        p_nominal = p
    if out_dir is None:
        out_dir = save_dir
    # Create the output directory before any computation so a missing
    # directory cannot discard a finished (possibly hours-long) sweep.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    def make_record(S, A, n, run, v):
        # Single definition of the per-run record stored in `results`.
        return {"S": S, "A": A, "n": n, "run": run, "gamma": gamma,
                "delta": delta, "divergence": divergence, "k": k,
                "v": v}

    results = {}
    total_iterations = len(S_values) * len(A_values) * len(n_values)
    for S, A, n in tqdm(
            itertools.product(S_values, A_values, n_values),
            total=total_iterations,
            desc="Processing"
    ):
        start_time = time.time()
        if n is None:
            # k is now forwarded; it used to be recorded but silently replaced
            # by single_experiment's default of 2.0.
            v, _ = single_experiment(S, A, n, p_nominal, gamma, delta, divergence, k)
            results[(S, A, n, 0)] = make_record(S, A, n, 0, v)
        else:
            for run in range(runs):
                print(S, A, n, run)
                seed = int(str(S) + str(A) + str(run))
                np.random.seed(seed)
                v, _ = single_experiment(S, A, n, p_nominal, gamma, delta, divergence, k)
                results[(S, A, n, run)] = make_record(S, A, n, run, v)
        end_time = time.time()
        logging.info(f"Time taken for S={S}, A={A}, n={n}: {end_time - start_time:.2f} seconds")

    result_path = os.path.join(out_dir, f'results_S={S_values[0]}_A={A_values[0]}.pkl')
    with open(result_path, 'wb') as f:
        pickle.dump(results, f)

    # show result keys
    result_keys = [tuple(str(part) for part in key[:4]) for key in results]
    logging.info(f"result_keys: {result_keys}")
    return results


def generate_param_text(S_values, A_values):
    """
    Write 'general_mdp_params.txt' in the current working directory with one
    "S A" line per (S, A) combination, S varying slowest.

    Intended for slurm runs, where each line supplies the parameters of one
    job in a job array.
    """
    with open('general_mdp_params.txt', 'w') as param_file:
        param_file.writelines(
            f"{S} {A}\n"
            for S in S_values
            for A in A_values
        )


if __name__ == '__main__':
    # ------------------------------------------------------------ #
    # Generate parameters for slurm job array
    # Every (S,A) will generate a job, which is 19*19 = 361  jobs in total.
    # For each (S,A), we do a grid search for n values and num_runs.
    # ------------------------------------------------------------ #
    # These full grids are only consumed by generate_param_text (commented out
    # below); S_values / A_values are replaced by the command-line values
    # before the experiments run.
    S_values = np.linspace(10, 1000, 19)
    S_values = np.array(S_values, dtype=int)
    A_values = np.linspace(10, 1000, 19)
    A_values = np.array(A_values, dtype=int)
    # None selects the table filled with the nominal probability (no sampling).
    n_values = [None, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000, 1000000]
    # generate_param_text(S_values, A_values)

    # ------------------------------------------------------------ #
    # Parse command line arguments
    # ------------------------------------------------------------ #
    parser = argparse.ArgumentParser()
    parser.add_argument('-S', type=int, default=100)
    parser.add_argument('-A', type=int, default=1000)
    args = parser.parse_args()

    # ------------------------------------------------------------ #
    # parameters
    # ------------------------------------------------------------ #
    runs = 10
    p = 0.8  # transition probability (read as a module global by all_experiment)
    delta = 0.1  # upper bounds for divergence constraint
    gamma = 0.9  # discount factor
    divergence = 'kl'  # type of divergence
    k = 2.0  # exponent for f-divergence
    S, A = args.S, args.A
    S_values, A_values = [S], [A]
    save_dir = 'results/general_mdp/'  # read as a module global by all_experiment
    log_path = "logs/general_mdp/"

    # logging.FileHandler and the final pickle dump both fail when their
    # directory is missing, so create both before doing any work.
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(save_dir, exist_ok=True)

    # ------------------------------------------------------------ #
    # logging
    # ------------------------------------------------------------ #
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(os.path.join(log_path, f"general_mdp_S={S}_A={A}.log")),
            logging.StreamHandler(),
        ]
    )
    logging.info(f"Running with S={S}, A={A}")
    # ------------------------------------------------------------ #
    # grid search for all parameters combinations
    # ------------------------------------------------------------ #
    results = all_experiment(S_values, A_values, n_values, runs, gamma, delta, divergence, k)
